import tensorflow as tf
import numpy as np

from config.caernn_config import CAERNNCongfig


def linear(input_, output_size, scope=None, stddev=0.02, bias_start=0.0, with_w=False):
    """Affine layer ``input_ @ Matrix + bias`` with variables created via ``tf.get_variable``.

    Args:
        input_: rank-2 tensor whose static second dimension gives the number of input features.
        output_size: number of output columns.
        scope: variable scope holding ``Matrix`` and ``bias``; defaults to ``"Linear"``.
        stddev: standard deviation of the normal initializer used for ``Matrix``.
        bias_start: constant initial value of ``bias``.
        with_w: when True, also return the weight matrix and bias variables.

    Returns:
        The affine output tensor, or ``(output, Matrix, bias)`` when ``with_w`` is True.
    """
    in_features = input_.get_shape().as_list()[1]

    with tf.variable_scope(scope or "Linear"):
        weights = tf.get_variable("Matrix", [in_features, output_size], tf.float32,
                                  tf.random_normal_initializer(stddev=stddev))
        biases = tf.get_variable("bias", [output_size],
                                 initializer=tf.constant_initializer(bias_start))
        affine = tf.matmul(input_, weights) + biases

    if not with_w:
        return affine
    return affine, weights, biases


class AutoEncoderRNNCell(tf.contrib.rnn.RNNCell):
    """Recurrent autoencoder cell wrapping an inner LSTM.

    At each step the input row is split into ``x`` (``x_dim`` columns) and ``y``
    (``y_dim`` columns). An encoder maps ``(y, x, previous LSTM output)`` to a latent
    code ``z`` (``latent_dim`` columns). A decoder maps ``z`` to an unnormalized
    ``x_dim``-wide output. The inner LSTM then consumes ``(z, y, x)`` to advance the state.

    The per-step output is ``concat([recon_output, z])``, laid out as ``output_dim_list``.
    """

    def __init__(self, config):
        """Read layer sizes from ``config.Arch.CAERNN`` and create the inner LSTM."""
        self.config = config
        x_dim = self.config.Arch.CAERNN.x_dim
        y_dim = self.config.Arch.CAERNN.y_dim
        h_dim = self.config.Arch.CAERNN.hidden_dim
        z_dim = self.config.Arch.CAERNN.latent_dim
        self.n_h = h_dim
        self.n_x = x_dim
        self.n_y = y_dim
        self.n_z = z_dim
        embed_dim = self.n_z
        self.lstm = tf.nn.rnn_cell.LSTMCell(num_units=self.n_h, state_is_tuple=True)
        # Column layout of the per-step output: [reconstruction (x_dim), embedding (latent_dim)].
        self.output_dim_list = [self.n_x, embed_dim]

    @property
    def state_size(self):
        """Sizes of the inner LSTM's ``(c, h)`` state."""
        return (self.n_h, self.n_h)

    @property
    def output_size(self):
        """Width of the per-step output: ``x_dim + latent_dim``."""
        return sum(self.output_dim_list)

    def __call__(self, input, state, scope="caernn", inherit_upper_post=False):
        """Run one step of the cell.

        Args:
            input: step input of width ``x_dim + y_dim``. The first ``x_dim`` columns are
                used as ``x`` and the remaining ``y_dim`` columns as ``y``.
            state: ``(c, m)`` state of the inner LSTM; ``m`` is fed to the encoder.
            scope: variable scope for this step's variables.
            inherit_upper_post: unused; kept for interface compatibility.

        Returns:
            ``(cell_output, new_state)``. ``cell_output`` is ``concat([recon_output, z])``
            (``x_dim + latent_dim`` columns, in ``output_dim_list`` order). ``new_state``
            is the inner LSTM's new state.
        """
        with tf.variable_scope(scope or type(self).__name__):
            c, m = state

            [x, y] = tf.split(input, [self.n_x, self.n_y], axis=1)
            with tf.variable_scope("Encoder"):
                x_phi = tf.nn.relu(linear(x, self.n_h, scope='Linear_x_phi'))
                y_phi = tf.nn.relu(linear(y, self.n_h, scope='Linear_y_phi'))
                # The previous LSTM output m conditions the encoding on the history.
                y_m = tf.nn.relu(linear(tf.concat(values=[y_phi, x_phi, m], axis=1), self.n_h, scope='Linear_ym'))
                z = linear(y_m, self.n_z, scope='Linear_z')
            with tf.variable_scope("decoder_output"):
                d_z = tf.nn.relu(linear(z, self.n_h, scope='Linear_decoder'))
                recon_output = linear(d_z, self.n_x, scope='Linear_recon')

            with tf.variable_scope("hidden_state"):
                zyx_phi = tf.nn.relu(linear(tf.concat(values=(z, y, x), axis=1), self.n_h, scope='Linear_hidden'))
                _, state2 = self.lstm(zyx_phi, state)

        player_encoding = z

        # The order must match output_dim_list ([x_dim, latent_dim]). CAERNN splits the
        # output with that list and reads the first x_dim columns as the reconstruction.
        cell_output = tf.concat(values=(recon_output, player_encoding), axis=1)
        return cell_output, state2


class CAERNN():
    """Recurrent autoencoder (built from AutoEncoderRNNCell) with optional supervised heads.

    ``__init__`` only creates the placeholders and the LSTM cells of the heads. Calling the
    instance (``model()``) builds the rest of the graph, the losses and the training ops.
    """

    def __init__(self, config, extra_prediction_flag=False):
        """Create the input placeholders and the head LSTM cells.

        Args:
            config: configuration object. ``config.Learn`` and
                ``config.Arch.{CAERNN, WIN, Predict}`` are read here and in ``__call__``.
            extra_prediction_flag: if True, ``__call__`` also builds the action-prediction head.
        """
        self.extra_prediction_flag = extra_prediction_flag
        self.win_score_diff = True  # build the score-difference head in __call__
        self.predict_action_goal = True  # NOTE(review): not read anywhere in this class
        self.config = config
        max_seq_length = self.config.Learn.max_seq_length
        x_dim = self.config.Arch.CAERNN.x_dim
        y_dim = self.config.Arch.CAERNN.y_dim

        # Per-step targets of the reconstruction loss: [batch, max_seq_length, x_dim].
        self.target_data_ph = tf.placeholder(dtype=tf.float32,
                                             shape=[None, max_seq_length, x_dim], name='target_data')
        # Per-step inputs: the first x_dim columns are x, the next y_dim columns are y.
        self.input_data_ph = tf.placeholder(dtype=tf.float32,
                                            shape=[None, max_seq_length, x_dim + y_dim],
                                            name='input_data')
        # Non-zero entries mark the (trace, step) positions counted by the likelihood loss.
        self.selection_matrix_ph = tf.placeholder(dtype=tf.int32,
                                                  shape=[None, max_seq_length],
                                                  name='selection_matrix')
        # Number of valid steps of each trace (used as the RNN sequence_length).
        self.trace_length_ph = tf.placeholder(dtype=tf.int32, shape=[None], name='trace_length')

        self.score_diff_target_ph = tf.placeholder(dtype=tf.float32,
                                                   shape=[None, 3], name='win_target')

        self.action_pred_target_ph = tf.placeholder(dtype=tf.float32,
                                                    shape=[None, self.config.Arch.Predict.output_size],
                                                    name='action_predict')

        embed_dim = self.config.Arch.CAERNN.latent_dim

        self.cell_output_dim_list = [x_dim, embed_dim]
        self.cell_output_names = ["output", "embedding"]

        self.score_diff_lstm_cell = []
        self.action_lstm_cell = []
        self.build_lstm_models()

    def build_lstm_models(self):
        """Create the LSTM cells of the score-difference (WIN) and action-prediction heads."""
        with tf.name_scope("win"):
            with tf.name_scope("LSTM-layer"):
                for _ in range(self.config.Arch.WIN.lstm_layer_num):
                    self.score_diff_lstm_cell.append(
                        tf.nn.rnn_cell.LSTMCell(num_units=self.config.Arch.WIN.h_size, state_is_tuple=True,
                                                initializer=tf.random_uniform_initializer(-0.05, 0.05)))

        with tf.name_scope("prediction"):
            with tf.name_scope("LSTM-layer"):
                for _ in range(self.config.Arch.Predict.lstm_layer_num):
                    self.action_lstm_cell.append(
                        tf.nn.rnn_cell.LSTMCell(num_units=self.config.Arch.Predict.h_size, state_is_tuple=True,
                                                initializer=tf.random_uniform_initializer(-0.05, 0.05)))

    def _head_rnn_input(self, embed_shape):
        """Concatenate each step's y columns of the input with that step's CAERNN embedding."""
        x_dim = self.config.Arch.CAERNN.x_dim
        y_dim = self.config.Arch.CAERNN.y_dim
        y_features = self.input_data_ph[:, :, x_dim:x_dim + y_dim]
        embedding = tf.reshape(self.player_embedding, shape=embed_shape)
        return tf.concat([y_features, embedding], axis=2)

    def _stacked_rnn_last_output(self, rnn_input, cells, h_size, scope_prefix):
        """Run ``rnn_input`` through the stacked ``cells``; return each trace's last valid output.

        Layer ``i`` uses variable scope ``'<scope_prefix>_<i>'``. ``h_size`` must be the
        ``num_units`` of the last cell.
        """
        if not cells:
            raise ValueError('at least one LSTM layer is required for {0}'.format(scope_prefix))
        rnn_output = rnn_input
        for i, cell in enumerate(cells):
            rnn_output, _ = tf.nn.dynamic_rnn(
                inputs=rnn_output, cell=cell,
                sequence_length=self.trace_length_ph, dtype=tf.float32,
                scope='{0}_{1}'.format(scope_prefix, i))
        # Flatten to one row per (trace, step) and pick the last valid step of each trace.
        return tf.gather(tf.reshape(rnn_output, [-1, h_size]), self.select_index)

    @staticmethod
    def _dense_head(head_input, layer_number, layer_size, output_size, hidden_scope, output_scope):
        """Apply ``layer_number - 1`` ReLU dense layers, then a linear output layer.

        The first hidden layer keeps the scope name ``hidden_scope``. Later hidden layers
        get a ``_<j>`` suffix so their variables do not collide.
        """
        hidden = head_input
        for j in range(layer_number - 1):
            scope = hidden_scope if j == 0 else '{0}_{1}'.format(hidden_scope, j)
            hidden = tf.nn.relu(linear(hidden, output_size=layer_size, scope=scope))
        return linear(hidden, output_size=output_size, scope=output_scope)

    @staticmethod
    def _masked_softmax_cross_entropy(logits, onehot_labels, condition):
        """Per-row softmax cross-entropy, set to zero wherever ``condition`` is False."""
        ce_loss_all = tf.losses.softmax_cross_entropy(onehot_labels=onehot_labels, logits=logits,
                                                      reduction=tf.losses.Reduction.NONE)
        return tf.where(condition, ce_loss_all, tf.zeros_like(ce_loss_all))

    def __call__(self):
        """Build the CAERNN graph, its optional heads, and the training ops.

        Always sets: ``cell``, ``initial_state_c``/``initial_state_h``, ``x_recon``,
        ``player_embedding``, ``select_index``, ``z_encoder_output``,
        ``final_state_c``/``final_state_h``, ``output``, ``likelihood_loss``, ``train_ll_op``
        and ``train_general_op``.

        When ``win_score_diff`` is set: ``diff_output``, ``diff_loss``, ``diff`` and
        ``train_diff_op``.

        When ``extra_prediction_flag`` is set: ``action_pred_output`` (softmax
        probabilities), ``action_pred_loss`` and ``train_action_pred_op``.
        """
        batch_size = tf.shape(self.input_data_ph)[0]
        max_seq_length = self.config.Learn.max_seq_length
        x_dim = self.config.Arch.CAERNN.x_dim

        with tf.variable_scope('caernn'):
            self.cell = AutoEncoderRNNCell(config=self.config)

            self.initial_state_c, self.initial_state_h = self.cell.zero_state(
                batch_size=batch_size, dtype=tf.float32)

            flat_target_data = tf.reshape(self.target_data_ph, [-1, x_dim])
            caernn_outputs, last_state = tf.nn.dynamic_rnn(cell=self.cell, inputs=self.input_data_ph,
                                                           sequence_length=self.trace_length_ph,
                                                           initial_state=tf.contrib.rnn.LSTMStateTuple(
                                                               self.initial_state_c,
                                                               self.initial_state_h))

        # dynamic_rnn output is [batch, time, output_size]. Flatten it to one row per
        # (trace, step) in batch-major order, then split the columns into the cell's outputs.
        flat_outputs = tf.reshape(caernn_outputs, [batch_size * max_seq_length, self.cell.output_size])
        self.x_recon, z_embedding = tf.split(flat_outputs, num_or_size_splits=self.cell.output_dim_list,
                                             axis=1)
        self.player_embedding = z_embedding
        embed_shape = [batch_size, max_seq_length, self.config.Arch.CAERNN.latent_dim]
        # Row of each trace's last valid step in the flattened (batch * max_seq_length) tensors.
        self.select_index = tf.range(0, batch_size) * max_seq_length + (self.trace_length_ph - 1)
        self.z_encoder_output = tf.gather(z_embedding, self.select_index)
        self.final_state_c, self.final_state_h = last_state
        condition = tf.cast(tf.reshape(self.selection_matrix_ph, shape=[-1]), tf.bool)

        self.output = tf.reshape(tf.nn.softmax(self.x_recon),
                                 shape=[batch_size, tf.shape(self.input_data_ph)[1], -1])

        with tf.variable_scope('cross_entropy'):
            likelihood_loss = self._masked_softmax_cross_entropy(self.x_recon, flat_target_data, condition)

        with tf.variable_scope('caernn_cost'):
            self.likelihood_loss = tf.reshape(likelihood_loss, shape=[batch_size, max_seq_length, -1])

        tvars_caernn = tf.trainable_variables(scope='caernn')
        for t in tvars_caernn:
            print('caernn_var: ' + str(t.name))
        # The autoencoder is trained on the likelihood alone, so both ops share one gradient set.
        caernn_grads = tf.gradients(tf.reduce_mean(self.likelihood_loss), tvars_caernn)
        optimizer = tf.train.AdamOptimizer(self.config.Learn.learning_rate)
        self.train_ll_op = optimizer.apply_gradients(zip(caernn_grads, tvars_caernn))
        self.train_general_op = optimizer.apply_gradients(zip(caernn_grads, tvars_caernn))

        if self.win_score_diff:
            with tf.variable_scope('score_diff'):
                score_diff_rnn_last = self._stacked_rnn_last_output(
                    self._head_rnn_input(embed_shape), self.score_diff_lstm_cell,
                    h_size=self.config.Arch.WIN.h_size, scope_prefix='score_diff_rnn')
                self.diff_output = self._dense_head(
                    score_diff_rnn_last,
                    layer_number=self.config.Arch.WIN.dense_layer_number,
                    layer_size=self.config.Arch.WIN.dense_layer_size,
                    output_size=3, hidden_scope='win_dense_Linear', output_scope='score_diff')

                with tf.variable_scope('score_diff_cost'):
                    self.diff_loss = tf.square(self.score_diff_target_ph - self.diff_output)
                    self.diff = tf.abs(self.score_diff_target_ph - self.diff_output)
            if self.config.Learn.integral_update_flag:
                tvars_score_diff = tf.trainable_variables()
            else:
                tvars_score_diff = tf.trainable_variables(scope='score_diff')
            for t in tvars_score_diff:
                print('tvars_score_diff: ' + str(t.name))
            score_diff_grads = tf.gradients(tf.reduce_mean(self.diff_loss), tvars_score_diff)
            self.train_diff_op = optimizer.apply_gradients(zip(score_diff_grads, tvars_score_diff))

        if self.extra_prediction_flag:
            with tf.variable_scope('prediction'):
                action_pred_rnn_last = self._stacked_rnn_last_output(
                    self._head_rnn_input(embed_shape), self.action_lstm_cell,
                    h_size=self.config.Arch.Predict.h_size, scope_prefix='action_pred_rnn')
                action_pred_logits = self._dense_head(
                    action_pred_rnn_last,
                    layer_number=self.config.Arch.Predict.dense_layer_number,
                    layer_size=self.config.Arch.Predict.dense_layer_size,
                    output_size=self.config.Arch.Predict.output_size,
                    hidden_scope='action_dense_Linear', output_scope='action_next')
                self.action_pred_output = tf.nn.softmax(action_pred_logits)

                with tf.variable_scope('action_pred_cost'):
                    # softmax_cross_entropy applies its own softmax, so it must get the raw logits.
                    self.action_pred_loss = tf.losses.softmax_cross_entropy(
                        onehot_labels=self.action_pred_target_ph, logits=action_pred_logits,
                        reduction=tf.losses.Reduction.NONE)
            if self.config.Learn.integral_update_flag:
                tvars_action_pred = tf.trainable_variables()
            else:
                tvars_action_pred = tf.trainable_variables(scope='prediction')
            for t in tvars_action_pred:
                print('tvars_action_pred: ' + str(t.name))
            action_grads = tf.gradients(tf.reduce_mean(self.action_pred_loss), tvars_action_pred)
            self.train_action_pred_op = optimizer.apply_gradients(zip(action_grads, tvars_action_pred))


if __name__ == '__main__':
    # Smoke test: load the YAML config, build the whole graph once, then report.
    config_file = "../environment_settings/icehockey_caernn_PlayerLocalId_predict_nex_goal_config.yaml"
    model = CAERNN(config=CAERNNCongfig.load(config_file), extra_prediction_flag=True)
    model()
    print('testing')
